from axelrod.action import Action

from axelrod.player import Player

# Short module-level aliases for Action.C and Action.D.
C, D = Action.C, Action.D

def create_policy(pCC, pCD, pDC, pDD):
    """
    Build the dict that represents a Policy.

    A Policy pairs each possible previous outcome -- (C, C), (C, D), (D, C)
    or (D, D) -- with the probability of cooperating after that outcome.

    Parameters

    pCC, pCD, pDC, pDD : float
        Probability of cooperation after the outcome named by the suffix.
        Each is expected to lie between 0 and 1; the values are stored as
        given, no validation is performed here.

    Returns

    dict
        Maps the outcome tuples (C, C), (C, D), (D, C), (D, D), in that
        insertion order, to pCC, pCD, pDC, pDD respectively.
    """
    previous_outcomes = ((C, C), (C, D), (D, C), (D, D))
    probabilities = (pCC, pCD, pDC, pDD)
    return dict(zip(previous_outcomes, probabilities))

def action_to_int(action):
    """Convert an action to an int: 1 if it equals C, 0 for anything else."""
    return 1 if action == C else 0

def minimax_tree_search(begin_node, policy, max_depth):
    """
    Recursively solve the search tree rooted at begin_node.

    The tree is built on the fly: every node is expanded through its
    get_siblings() method while the recursion descends.

    Parameters

    begin_node : node object
        Must provide is_stochastic() and get_siblings(). Stochastic nodes
        must also expose pC; deterministic nodes must expose depth and
        get_value(), and their get_siblings() receives the policy.
    policy : dict
        Opponent's policy, handed unchanged to deterministic nodes when they
        are expanded.
    max_depth : int
        Depth at which deterministic nodes are treated as end nodes.

    Returns

    - For a deterministic node at depth 0 (checked after the end-node test):
      a tuple of two values, the utility of the first sibling (choice C)
      and the utility of the second sibling (choice D).
    - For any other node: a single number, the value of that node.
    - None for a deterministic node whose depth is neither 0 nor
      <= max_depth.
    """
    if begin_node.is_stochastic():
        # Stochastic nodes are expanded without any depth test; the depth
        # limit is only enforced on deterministic nodes below.
        outcomes = begin_node.get_siblings()
        p_first = begin_node.pC
        # Expected value: the first sibling is weighted by pC, the second
        # by the complementary probability.
        return p_first * minimax_tree_search(
            outcomes[0], policy, max_depth
        ) + (1 - p_first) * minimax_tree_search(outcomes[1], policy, max_depth)

    level = begin_node.depth
    if level == max_depth:
        # End node: its value is simply the score of its outcome. This test
        # takes priority over the root test below.
        return begin_node.get_value()

    if not (level == 0 or level < max_depth):
        # Beyond the search horizon: nothing to evaluate.
        return None

    choices = begin_node.get_siblings(policy)
    value_if_first = minimax_tree_search(choices[0], policy, max_depth)
    value_if_second = minimax_tree_search(choices[1], policy, max_depth)
    own_score = begin_node.get_value()

    if level == 0:
        # Root: report both options instead of collapsing them, so the
        # caller can pick between C and D.
        return (value_if_first + own_score, value_if_second + own_score)

    # Inner node: keep the better option and add this node's own score.
    return max(value_if_first, value_if_second) + own_score

def move_gen(outcome, policy, depth_search_tree=5):
    """
    Choose the best next move with the tree-search procedure.

    Parameters

    outcome : sequence
        The last outcome; its first two items are used as the two actions of
        the root node of the search tree.
    policy : dict
        Opponent's policy, forwarded to the tree search.
    depth_search_tree : int
        Maximum depth explored by the tree search (default 5).

    Returns

    C or D, whichever has the highest expected value; C when both values
    are equal.
    """
    root = DeterministicNode(outcome[0], outcome[1], depth=0)
    utilities = minimax_tree_search(root, policy, depth_search_tree)
    # utilities is ordered (value of C, value of D). index() yields the
    # first position holding the maximum, so a tie resolves to C.
    best_utility = max(utilities)
    candidates = (C, D)
    return candidates[utilities.index(best_utility)]

class Node:
    """
    Base class of the nodes that make up the tree explored by the
    tree-search procedure. The tree mixes Deterministic and Stochastic
    nodes, because the opponent's strategy is learned as a probability
    distribution.

    Both methods below are placeholders: they raise NotImplementedError
    until a subclass overrides them.
    """

    def get_siblings(self):
        """Abstract: subclasses return the nodes reachable from this one."""
        raise NotImplementedError("subclasses must override get_siblings()!")

    def is_stochastic(self):
        """Abstract: subclasses report whether the node is stochastic."""
        raise NotImplementedError("subclasses must override is_stochastic()!")

class DeterministicNode(Node):
    """
    Node holding one fully known outcome -- (C, C), (C, D), (D, C) or
    (D, D) -- whose siblings are obtained by a deterministic choice
    between C and D.
    """

    def __init__(self, action1, action2, depth):
        # action1 / action2: the two actions forming this node's outcome.
        # depth: position of the node in the search tree (the root is
        # created with depth 0 by move_gen).
        self.action1, self.action2, self.depth = action1, action2, depth

    def get_siblings(self, policy):
        """
        Build the two StochasticNode siblings of this node.

        One sibling is created for choice C and one for choice D, returned
        in that order. Both receive this node's depth and the same
        probability: the policy entry for this node's outcome
        (action1, action2).
        """
        p_cooperate = policy[(self.action1, self.action2)]
        return (
            StochasticNode(C, p_cooperate, self.depth),
            StochasticNode(D, p_cooperate, self.depth),
        )

    def is_stochastic(self):
        """Always returns False: this node type is not stochastic."""
        return False

    def get_value(self):
        """
        Return the score attached to this node's outcome.

        The table gives 3 for (C, C), 0 for (C, D), 5 for (D, C) and 1 for
        (D, D) -- presumably the payoff of the player who played action1;
        confirm against the game definition.
        """
        outcome = (self.action1, self.action2)
        payoffs = {(C, C): 3, (C, D): 0, (D, C): 5, (D, D): 1}
        return payoffs[outcome]